#!/usr/bin/env python

from __future__ import print_function

from setuptools import setup, find_packages, distutils
from torch.utils.cpp_extension import BuildExtension, CppExtension
import distutils.ccompiler
import distutils.command.clean
import glob
import inspect
import multiprocessing
import multiprocessing.pool
import os
import platform
import re
import shutil
import subprocess
import sys

# Absolute path of the directory that contains this script.
base_dir = os.path.split(os.path.abspath(__file__))[0]
# Path of the 'third_party' directory located directly under base_dir.
third_party_path = os.path.join(base_dir, 'third_party')


def _get_build_mode():
  """Return the first command line argument that is not an option.

  sys.argv[0] is skipped. An argument counts as an option when it starts
  with '-'. Returns None when no such argument is present.
  """
  positional = (arg for arg in sys.argv[1:] if not arg.startswith('-'))
  return next(positional, None)


def _check_env_flag(name, default=''):
  """Tell whether the environment variable `name` holds a truthy value.

  The value (or `default` when the variable is unset) is compared case
  insensitively against ON, 1, YES, TRUE and Y.
  """
  truthy_values = ['ON', '1', 'YES', 'TRUE', 'Y']
  raw_value = os.getenv(name, default)
  return raw_value.upper() in truthy_values


def get_git_head_sha(base_dir):
  """Return the HEAD commit ids of the XLA and PyTorch checkouts.

  Args:
    base_dir: Directory of the XLA checkout. The PyTorch checkout is looked
      up in its parent directory and is only queried when that parent
      contains a '.git' directory.

  Returns:
    A (xla_git_sha, torch_git_sha) tuple of strings. An entry is the empty
    string when the commit id cannot be determined, e.g. because git is not
    installed or the directory is not a git repository.
  """

  def _head_sha(repo_dir):
    # Best effort: a missing git binary (OSError) or a failing git command
    # (CalledProcessError) must not abort the whole build script.
    try:
      output = subprocess.check_output(['git', 'rev-parse', 'HEAD'],
                                       cwd=repo_dir)
    except (OSError, subprocess.CalledProcessError) as error:
      print(
          'Unable to read the git HEAD of {}: {}'.format(repo_dir, error),
          file=sys.stderr)
      return ''
    return output.decode('ascii').strip()

  xla_git_sha = _head_sha(base_dir)
  torch_dir = os.path.join(base_dir, '..')
  if os.path.isdir(os.path.join(torch_dir, '.git')):
    torch_git_sha = _head_sha(torch_dir)
  else:
    torch_git_sha = ''
  return xla_git_sha, torch_git_sha


def get_build_version(xla_git_sha):
  """Return the version string of the package being built.

  The base version comes from the TORCH_XLA_VERSION environment variable
  and defaults to '1.6'. When the VERSIONED_XLA_BUILD flag is enabled, the
  first seven characters of `xla_git_sha` are appended as '+<sha>'.

  Args:
    xla_git_sha: Commit id of the XLA checkout. When it is empty or None the
      base version is returned unchanged, so no dangling '+' is produced.

  Returns:
    The version string.
  """
  version = os.getenv('TORCH_XLA_VERSION', '1.6')
  if _check_env_flag('VERSIONED_XLA_BUILD', default='0') and xla_git_sha:
    try:
      version += '+' + xla_git_sha[:7]
    except TypeError:
      # The commit id is not string-like; keep the plain version.
      pass
  return version


def create_version_files(base_dir, version, xla_git_sha, torch_git_sha):
  """Write the autogenerated Python and C++ version files.

  Prints the version and both commit ids, then overwrites
  torch_xla/version.py and torch_xla/csrc/version.cpp below `base_dir`.

  Args:
    base_dir: Root directory that contains the torch_xla package.
    version: Version string stored as __version__.
    xla_git_sha: XLA commit id stored in both files.
    torch_git_sha: PyTorch commit id stored in both files.
  """
  print('Building torch_xla version: {}'.format(version))
  print('XLA Commit ID: {}'.format(xla_git_sha))
  print('PyTorch Commit ID: {}'.format(torch_git_sha))

  python_lines = [
      '# Autogenerated file, do not edit!\n',
      "__version__ = '{}'\n".format(version),
      "__xla_gitrev__ = '{}'\n".format(xla_git_sha),
      "__torch_gitrev__ = '{}'\n".format(torch_git_sha),
  ]
  cpp_lines = [
      '// Autogenerated file, do not edit!\n',
      '#include "torch_xla/csrc/version.h"\n\n',
      'namespace torch_xla {\n\n',
      'const char XLA_GITREV[] = {{"{}"}};\n'.format(xla_git_sha),
      'const char TORCH_GITREV[] = {{"{}"}};\n\n'.format(torch_git_sha),
      '}  // namespace torch_xla\n',
  ]

  py_version_path = os.path.join(base_dir, 'torch_xla', 'version.py')
  with open(py_version_path, 'w') as py_file:
    py_file.writelines(python_lines)

  cpp_version_path = os.path.join(base_dir, 'torch_xla', 'csrc', 'version.cpp')
  with open(cpp_version_path, 'w') as cpp_file:
    cpp_file.writelines(cpp_lines)


def generate_xla_aten_code(base_dir):
  """Run scripts/generate_code.sh found below `base_dir`.

  Exits the process with status 1 when the script returns a non-zero exit
  code.
  """
  script_path = os.path.join(base_dir, 'scripts', 'generate_code.sh')
  generate_code_cmd = [script_path]
  exit_code = subprocess.call(generate_code_cmd)
  if exit_code == 0:
    return
  print(
      'Failed to generate ATEN bindings: {}'.format(generate_code_cmd),
      file=sys.stderr)
  sys.exit(1)


def build_extra_libraries(base_dir, build_mode=None):
  """Run build_torch_xla_libs.sh found in `base_dir`.

  Args:
    base_dir: Directory that contains the build script.
    build_mode: Optional argument forwarded to the script; nothing is
      forwarded when it is None.

  Exits the process with status 1 when the script returns a non-zero exit
  code.
  """
  build_libs_cmd = [os.path.join(base_dir, 'build_torch_xla_libs.sh')]
  if build_mode is not None:
    build_libs_cmd.append(build_mode)
  exit_code = subprocess.call(build_libs_cmd)
  if exit_code == 0:
    return
  print(
      'Failed to build external libraries: {}'.format(build_libs_cmd),
      file=sys.stderr)
  sys.exit(1)


def generate_protos(base_dir, third_party_path):
  """Generate C++ sources from the application proto files.

  The proto files are read from torch_xla/pb/src/ and the generated files
  are written to torch_xla/pb/cpp/ (both below `base_dir`). The protoc
  binary is taken from the tensorflow tree below `third_party_path`.
  Nothing is done when there are no proto files; the process exits with
  status 1 when protoc returns a non-zero exit code.
  """
  proto_files = glob.glob(os.path.join(base_dir, 'torch_xla/pb/src/*.proto'))
  if not proto_files:
    return

  protoc = os.path.join(
      third_party_path,
      'tensorflow/bazel-out/host/bin/external/com_google_protobuf/protoc')
  tensorflow_include = os.path.join(third_party_path, 'tensorflow')
  proto_include = os.path.join(base_dir, 'torch_xla/pb/src')
  cpp_out_dir = os.path.join(base_dir, 'torch_xla/pb/cpp')

  protoc_cmd = [
      protoc, '-I', tensorflow_include, '-I', proto_include, '--cpp_out',
      cpp_out_dir
  ]
  protoc_cmd.extend(proto_files)

  if subprocess.call(protoc_cmd) == 0:
    return
  print(
      'Failed to generate protobuf files: {}'.format(protoc_cmd),
      file=sys.stderr)
  sys.exit(1)


def _compile_parallel(self,
                      sources,
                      output_dir=None,
                      macros=None,
                      include_dirs=None,
                      debug=0,
                      extra_preargs=None,
                      extra_postargs=None,
                      depends=None):
  """Compile `sources` concurrently on a pool of threads.

  Intended as a replacement for distutils.ccompiler.CCompiler.compile; the
  parameter list must stay identical to that method because the signatures
  are compared before this function is planted.

  Returns:
    The list of object files reported by self._setup_compile().

  Raises:
    Whatever self._compile() raises for one of the objects.
  """
  # Those lines are copied from distutils.ccompiler.CCompiler directly.
  macros, objects, extra_postargs, pp_opts, build = self._setup_compile(
      output_dir, macros, include_dirs, sources, depends, extra_postargs)
  cc_args = self._get_cc_args(pp_opts, debug, extra_preargs)

  def compile_one(obj):
    # Objects without an entry in `build` are skipped (presumably because
    # they are already up to date -- see distutils' _setup_compile).
    try:
      src, ext = build[obj]
    except KeyError:
      return
    self._compile(obj, src, ext, cc_args, extra_postargs, pp_opts)

  pool = multiprocessing.pool.ThreadPool(multiprocessing.cpu_count())
  try:
    # list() drains the iterator, which waits for every job and re-raises
    # the first compilation error.
    list(pool.imap(compile_one, objects))
  finally:
    # Always release the pool threads instead of leaving them to the
    # garbage collector.
    pool.terminate()
    pool.join()
  return objects


# Plant the parallel compile function. This is skipped when COMPILE_PARALLEL
# is disabled or when the signature of CCompiler.compile differs from the
# signature of the replacement.
if _check_env_flag('COMPILE_PARALLEL', default='1'):
  try:
    if (inspect.signature(distutils.ccompiler.CCompiler.compile) ==
        inspect.signature(_compile_parallel)):
      distutils.ccompiler.CCompiler.compile = _compile_parallel
  except Exception:
    # Best effort: keep the stock serial compile when the signatures cannot
    # be inspected. KeyboardInterrupt and SystemExit are no longer swallowed.
    pass


class Clean(distutils.command.clean.clean):
  """Clean command that also deletes the files matched by .gitignore."""

  def run(self):
    """Remove everything matched by the .gitignore patterns, then run clean.

    Every non-empty line of .gitignore (read from the current directory) is
    used as a glob pattern, except for lines starting with '#'. Reading
    stops at the '# BEGIN NOT-CLEAN-FILES ' marker line.
    """
    comment_pattern = re.compile(r'^#( BEGIN NOT-CLEAN-FILES )?')
    with open('.gitignore', 'r') as gitignore:
      content = gitignore.read()
      for line in content.split('\n'):
        if not line:
          continue
        comment = comment_pattern.match(line)
        if comment is None:
          for path in glob.glob(line):
            try:
              os.remove(path)
            except OSError:
              # Not removable as a file; try to remove it as a directory.
              shutil.rmtree(path, ignore_errors=True)
          continue
        if comment.group(1):
          # Marker is found and stop reading .gitignore.
          break
        # Any other line beginning with '#' is ignored.

    # It's an old-style class in Python 2.7...
    distutils.command.clean.clean.run(self)


class Build(BuildExtension):
  """Extension build command that can also build the C++ tests."""

  def run(self):
    """Build the extension, then the C++ tests unless they are disabled.

    The tests are built with test/cpp/run_tests.sh -B when the
    BUILD_CPP_TESTS flag is enabled (the default). The process exits with
    status 1 when that script returns a non-zero exit code.
    """
    # The extension has to be built before the tests.
    BuildExtension.run(self)
    if not _check_env_flag('BUILD_CPP_TESTS', default='1'):
      return
    run_tests_script = os.path.join(base_dir, 'test/cpp/run_tests.sh')
    cmd = [run_tests_script, '-B']
    exit_code = subprocess.call(cmd)
    if exit_code != 0:
      print('Failed to build tests: {}'.format(cmd), file=sys.stderr)
      sys.exit(1)


# Commit ids and version are resolved for every command, including 'clean'.
xla_git_sha, torch_git_sha = get_git_head_sha(base_dir)
version = get_build_version(xla_git_sha)

# Every command except 'clean' needs the generated files and the support
# libraries in place before setup() runs.
build_mode = _get_build_mode()
if build_mode != 'clean':
  # Version info first (torch_xla.__version__).
  create_version_files(base_dir, version, xla_git_sha, torch_git_sha)

  # The generated code must exist before the sources are globbed.
  generate_xla_aten_code(base_dir)

  # Support libraries (ie, TF).
  build_extra_libraries(base_dir, build_mode=build_mode)

  # The protos can only be generated after third_party has been built.
  generate_protos(base_dir, third_party_path)

# Collect the sources to be built (patterns are relative to the current
# directory).
torch_xla_sources = [
    source for pattern in (
        'torch_xla/csrc/*.cpp',
        'torch_xla/csrc/ops/*.cpp',
        'torch_xla/pb/cpp/*.cc',
    ) for source in glob.glob(pattern)
]

# Constant known variables used throughout this file.
lib_path = os.path.join(base_dir, 'torch_xla/lib')
pytorch_source_path = os.getenv('PYTORCH_SOURCE_PATH',
                                os.path.dirname(base_dir))

# Include directories: this tree, the third party trees, then PyTorch.
include_dirs = [base_dir]
include_dirs.extend(
    os.path.join(third_party_path, subdir) for subdir in (
        'tensorflow/bazel-tensorflow',
        'tensorflow/bazel-bin',
        'tensorflow/bazel-tensorflow/external/protobuf_archive/src',
        'tensorflow/bazel-tensorflow/external/com_google_protobuf/src',
        'tensorflow/bazel-tensorflow/external/eigen_archive',
        'tensorflow/bazel-tensorflow/external/com_google_absl',
    ))
include_dirs.extend([
    pytorch_source_path,
    os.path.join(pytorch_source_path, 'torch/csrc'),
    os.path.join(pytorch_source_path, 'torch/lib/tmp_install/include'),
])

library_dirs = [lib_path]

extra_link_args = []

# Build flavor and host platform switches.
DEBUG = _check_env_flag('DEBUG')
IS_DARWIN = platform.system() == 'Darwin'
IS_LINUX = platform.system() == 'Linux'


def make_relative_rpath(path):
  """Return the linker flag adding `path` as a loader-relative rpath.

  The rpath is anchored at @loader_path when IS_DARWIN is set and at
  $ORIGIN otherwise.
  """
  flag_prefix = ('-Wl,-rpath,@loader_path/'
                 if IS_DARWIN else '-Wl,-rpath,$ORIGIN/')
  return flag_prefix + path


# Flags passed to the compiler for every source file.
extra_compile_args = [
    '-std=c++14',
    '-Wno-sign-compare',
    '-Wno-deprecated-declarations',
    '-Wno-return-type',
]

# Extra warning suppressions when the CC environment variable starts with
# 'clang'.
if os.getenv('CC', '').startswith('clang'):
  extra_compile_args.extend([
      '-Wno-macro-redefined',
      '-Wno-return-std-move',
  ])

# Debug builds disable optimization and keep symbols; other builds define
# NDEBUG.
if not DEBUG:
  extra_compile_args.append('-DNDEBUG')
else:
  extra_compile_args.extend(['-O0', '-g'])
  extra_link_args.extend(['-O0', '-g'])

extra_link_args.append('-lxla_computation_client')

# The single C++ extension module; its rpath points at torch_xla/lib
# relative to the location of the loaded module.
xla_extension = CppExtension(
    '_XLAC',
    torch_xla_sources,
    include_dirs=include_dirs,
    extra_compile_args=extra_compile_args,
    library_dirs=library_dirs,
    extra_link_args=extra_link_args + [make_relative_rpath('torch_xla/lib')],
)

setup(
    name='torch_xla',
    version=version,
    description='XLA bridge for PyTorch',
    url='https://github.com/pytorch/xla',
    author='PyTorch/XLA Dev Team',
    author_email='pytorch-xla@googlegroups.com',
    # Exclude the build files.
    packages=find_packages(exclude=['build']),
    ext_modules=[xla_extension],
    # Ship the shared libraries placed in torch_xla/lib with the package.
    package_data={'torch_xla': ['lib/*.so*']},
    data_files=[
        'test/cpp/build/test_ptxla',
        'scripts/fixup_binary.py',
    ],
    # Use the customized build and clean commands defined in this file.
    cmdclass={
        'build_ext': Build,
        'clean': Clean,
    })
